#! /usr/bin/env python2

import logging
import os
import multiprocessing
import subprocess
import sys

from lab.experiment import get_run_dir
from lab import tools

# Configure logging via the lab helper before any log messages are emitted.
tools.configure_logging()

# Lookup table from task ID to run ID: get_run_id() returns the entry at
# index (task_id - 1), so task IDs are 1-based.
SHUFFLED_TASK_IDS = [1, 2]

# Make sure we're in the experiment directory, i.e. the directory containing
# this script, regardless of the directory the script was started from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))


def get_run_id(task_id):
    """Return the run ID stored in SHUFFLED_TASK_IDS for a 1-based task ID."""
    # Task IDs start at 1, list indices at 0.
    index = task_id - 1
    return SHUFFLED_TASK_IDS[index]


def process_task(task_id):
    """Execute the run that belongs to *task_id* and report whether it failed.

    The run's ``./run`` executable is started inside the run directory with
    its stdout redirected to ``driver.log`` and its stderr redirected to
    ``driver.err`` (both inside the run directory). Afterwards, whichever of
    the two files is empty is deleted.

    Args:
        task_id: 1-based task ID; it is translated to a run ID with
            get_run_id().

    Returns:
        True if the run exited with a non-zero exit code or wrote anything
        to ``driver.err``, False otherwise.
    """
    run_id = get_run_id(task_id)
    run_dir = get_run_dir(run_id)
    # Compute the paths once so the cleanup below does not have to rely on
    # attributes of already-closed file objects.
    log_path = os.path.join(run_dir, 'driver.log')
    err_path = os.path.join(run_dir, 'driver.err')
    error = False
    with open(log_path, 'w') as driver_log:
        with open(err_path, 'w') as driver_err:
            logging.info(
                'Starting run {run_id} (TASK_ID {task_id}) in {run_dir}'.format(
                    run_id=run_id, task_id=task_id, run_dir=run_dir))
            try:
                subprocess.check_call(
                    ['./run'], cwd=run_dir,
                    stdout=driver_log, stderr=driver_err)
            except subprocess.CalledProcessError as err:
                # Record why the run counts as failed instead of silently
                # discarding the exception.
                logging.error(
                    'Run {run_id} exited with code {returncode}'.format(
                        run_id=run_id, returncode=err.returncode))
                error = True
    # Any output on stderr counts as a failure, even with exit code 0.
    if os.path.getsize(err_path) != 0:
        error = True
    # Do not leave empty log files behind.
    for path in (log_path, err_path):
        if os.path.getsize(path) == 0:
            os.remove(path)
    return error


def main():
    """Process all tasks in a pool of worker processes.

    Every task ID from 1 to len(SHUFFLED_TASK_IDS) is passed to
    process_task() in one of four worker processes.

    Raises:
        SystemExit: with an error message if the script was interrupted
            with Ctrl-C or if process_task() reported a failure for at
            least one task.
    """
    pool = multiprocessing.Pool(processes=4)
    num_tasks = len(SHUFFLED_TASK_IDS)
    result = pool.map_async(process_task, range(1, num_tasks + 1))
    interrupted = False
    try:
        # Use "timeout" to fix passing KeyboardInterrupts from children
        # (see https://stackoverflow.com/questions/1408356).
        result.wait(timeout=sys.maxint)
    except KeyboardInterrupt:
        logging.warning('Main script interrupted')
        interrupted = True
        pool.terminate()
    finally:
        pool.close()
        logging.info('Joining pool processes')
        pool.join()

    # After terminate() the map result is never completed, so calling
    # result.get() would block forever. Bail out before touching it.
    if interrupted:
        sys.exit("Error: Main script interrupted.")

    if any(result.get()):
        sys.exit("Error: At least one run failed.")


if __name__ == '__main__':
    main()
